{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Download the trackA checkpoint/data bundle from its public Google Drive link.\n",
        "# `confirm=t` is used to get past Drive's large-file confirmation page, so no\n",
        "# browser cookies or copied request headers are needed. Never paste browser\n",
        "# session cookies into a notebook: they grant access to the exporting Google account.\n",
        "# Note: wget takes request headers via --header, not -H (-H means --span-hosts).\n",
        "# -c lets a re-run resume a partial download instead of starting over.\n",
        "!wget -c -O trackA1.zip 'https://drive.usercontent.google.com/download?id=1hnxfzZiJUxWp4g56k6NHZA8WuAGpNlLD&export=download&confirm=t'"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iE9aOPvayuTx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Extract the project (checkpoint, code, validation features) into ./trackA0612.\n",
        "# -o overwrites without prompting, so re-running the cell does not block on\n",
        "# unzip's interactive 'replace?' question; -p makes mkdir a no-op if the\n",
        "# output directory already exists.\n",
        "!unzip -o trackA1.zip\n",
        "!mkdir -p infer_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CpfUVDGm2j9X",
        "outputId": "4beb88b4-02c5-4b8f-a6af-de6324e03b1e"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Archive:  trackA1.zip\n",
            "   creating: trackA0612/ckpt/\n",
            "  inflating: trackA0612/ckpt/config.yaml  \n",
            "  inflating: trackA0612/ckpt/epoch=189-step=39900.ckpt  \n",
            "   creating: trackA0612/data_transform/\n",
            "   creating: trackA0612/data_transform/__pycache__/\n",
            "  inflating: trackA0612/data_transform/__pycache__/transform.cpython-39.pyc  \n",
            "  inflating: trackA0612/data_transform/transform.py  \n",
            "   creating: trackA0612/dataset/\n",
            "  inflating: trackA0612/dataset/v45_val_feature.npy  \n",
            "   creating: trackA0612/networks/\n",
            "  inflating: trackA0612/networks/__init__.py  \n",
            "   creating: trackA0612/networks/common/\n",
            "   creating: trackA0612/networks/common/__pycache__/\n",
            "  inflating: trackA0612/networks/common/__pycache__/loss.cpython-39.pyc  \n",
            "  inflating: trackA0612/networks/common/__pycache__/loss_pk.cpython-39.pyc  \n",
            "  inflating: trackA0612/networks/common/__pycache__/loss_pw.cpython-39.pyc  \n",
            "  inflating: trackA0612/networks/common/__pycache__/normalization.cpython-39.pyc  \n",
            "  inflating: trackA0612/networks/common/LBase.py  \n",
            "  inflating: trackA0612/networks/common/loss.py  \n",
            "  inflating: trackA0612/networks/common/loss_pk.py  \n",
            "  inflating: trackA0612/networks/common/loss_pw.py  \n",
            "  inflating: trackA0612/networks/common/normalization.py  \n",
            "   creating: trackA0612/networks/Solver/\n",
            "   creating: trackA0612/networks/Solver/__pycache__/\n",
            "  inflating: trackA0612/networks/Solver/__pycache__/switch.cpython-39.pyc  \n",
            "  inflating: trackA0612/networks/Solver/__pycache__/transLinear.cpython-39.pyc  \n",
            "  inflating: trackA0612/networks/Solver/switch.py  \n",
            "  inflating: trackA0612/networks/Solver/transLinear.py  \n",
            "   creating: trackA0612/utils/\n",
            "  inflating: trackA0612/utils/__init__.py  \n",
            "  inflating: trackA0612/utils/modules.py  \n",
            "  inflating: trackA0612/utils/registry.py  \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Extra dependencies not preinstalled on Colab, pinned to the versions this\n",
        "# notebook was last run with so results stay reproducible.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q einops==0.8.0 lightning==2.3.3 wandb==0.17.4"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "RAae930Q1yNj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import yaml\n",
        "\n",
        "# Make the unzipped project importable. Prepend (rather than append) so the\n",
        "# project's generically named `utils` / `networks` packages win over any\n",
        "# same-named module already on sys.path; the membership check keeps re-runs\n",
        "# of this cell from stacking duplicate entries.\n",
        "PROJECT_DIR = 'trackA0612/'\n",
        "if PROJECT_DIR not in sys.path:\n",
        "    sys.path.insert(0, PROJECT_DIR)\n",
        "\n",
        "# Star import kept on purpose: presumably importing `networks` registers its\n",
        "# model classes for `class_builder` (see utils/registry.py) - TODO confirm\n",
        "# before replacing it with explicit names.\n",
        "from networks import *\n",
        "\n",
        "from utils import class_builder\n",
        "from data_transform.transform import UnitGaussianNormalizer\n",
        "\n",
        "# Imported after the star import so these aliases cannot be shadowed by names\n",
        "# that `networks` happens to re-export.\n",
        "import torch.nn as nn\n",
        "import pytorch_lightning as pl"
      ],
      "metadata": {
        "id": "A3QSDOXF1mRQ"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def load_config(config_path):\n",
        "  '''Parse the YAML file at config_path with yaml.safe_load and return the result.'''\n",
        "  with open(config_path, 'r') as config_file:\n",
        "    parsed = yaml.safe_load(config_file)\n",
        "  return parsed\n",
        "\n",
        "class MixNormalizer_8in(nn.Module):\n",
        "  def __init__(self,\n",
        "        pos_min,\n",
        "        pos_max,\n",
        "        feature_mean,\n",
        "        feature_std,\n",
        "        eps=1e-06):\n",
        "    '''\n",
        "    Normalizer for 8 in_channels, pos dimension should be 3, feature dimension should be 5.\n",
        "\n",
        "    Channels [..., :3] (positions) are min-max scaled with pos_min / pos_max;\n",
        "    channels [..., 3:] (features) are standardized with feature_mean / feature_std.\n",
        "    eps is added to every denominator to avoid division by zero.\n",
        "    encode/decode return new tensors and leave their input untouched.\n",
        "    '''\n",
        "    super().__init__()\n",
        "    self.pos_min = torch.tensor(pos_min)\n",
        "    self.pos_max = torch.tensor(pos_max)\n",
        "    self.feature_mean = torch.tensor(feature_mean)\n",
        "    self.feature_std = torch.tensor(feature_std)\n",
        "    self.eps = torch.tensor(eps)\n",
        "\n",
        "  def _stats(self, device):\n",
        "    '''Return (pos_min, pos_scale, feature_mean, feature_scale), all moved to device.'''\n",
        "    eps = self.eps.to(device)\n",
        "    pos_min = self.pos_min.to(device)\n",
        "    pos_scale = self.pos_max.to(device) - pos_min + eps\n",
        "    feature_mean = self.feature_mean.to(device)\n",
        "    feature_scale = self.feature_std.to(device) + eps\n",
        "    return pos_min, pos_scale, feature_mean, feature_scale\n",
        "\n",
        "  def encode(self, x):\n",
        "    '''Map raw channels of x (last dim) to normalized space; returns a new tensor.'''\n",
        "    pos_min, pos_scale, feature_mean, feature_scale = self._stats(x.device)\n",
        "    # Work on a copy: the slice assignments below would otherwise overwrite the caller's tensor.\n",
        "    x = x.clone()\n",
        "    x[..., :3] = (x[..., :3] - pos_min) / pos_scale\n",
        "    x[..., 3:] = (x[..., 3:] - feature_mean) / feature_scale\n",
        "    return x\n",
        "\n",
        "  def decode(self, x):\n",
        "    '''Inverse of encode: map normalized channels of x back to raw units; returns a new tensor.'''\n",
        "    pos_min, pos_scale, feature_mean, feature_scale = self._stats(x.device)\n",
        "    x = x.clone()\n",
        "    x[..., :3] = x[..., :3] * pos_scale + pos_min\n",
        "    x[..., 3:] = x[..., 3:] * feature_scale + feature_mean\n",
        "    return x\n",
        "\n",
        "class TrainingDataModule8in(pl.LightningDataModule):\n",
        "  '''Data module that only builds the fixed input/output normalizers (hard-coded statistics).'''\n",
        "  def __init__(self, ):\n",
        "    super().__init__()\n",
        "\n",
        "  def setup(self, stage=None):\n",
        "    '''Build the normalizers. stage is accepted (and ignored) so this matches the\n",
        "    LightningDataModule.setup(stage) hook signature that pl.Trainer calls.'''\n",
        "    self.setup_normalizer()\n",
        "\n",
        "  def setup_normalizer(self):\n",
        "    '''Create normalizer_x (8 input channels) and normalizer_y (1 target channel).\n",
        "\n",
        "    NOTE(review): the statistics below are hard-coded; presumably computed on the\n",
        "    training set -- TODO confirm they match the checkpoint being loaded.\n",
        "    '''\n",
        "    # Input normalizer: min/max for the 3 position channels, mean/std for the 5 feature channels.\n",
        "    self.data_normalizer_x = MixNormalizer_8in(\n",
        "      pos_min=np.array([-0.902528, 0.0066688,-2.85272]),\n",
        "      pos_max=np.array([0.90253, 1.99982, 2.86385]),\n",
        "      feature_mean=np.array([1.4136816e+00,\n",
        "                              6.6241097e-05,\n",
        "                              -8.8731814e-03,\n",
        "                              -2.1077667e-03,\n",
        "                              6.3469582e-03]),\n",
        "      feature_std=np.array([5.3759489e+00,\n",
        "                            6.4974260e-01,\n",
        "                            6.6663116e-01,\n",
        "                            3.6476558e-01,\n",
        "                            2.9094261e-03]),\n",
        "    )\n",
        "\n",
        "    # Target normalizer: mean/std of the single predicted channel.\n",
        "    self.data_normalizer_y = UnitGaussianNormalizer(\n",
        "      mean=np.array([-36.239616]),\n",
        "      std=np.array([48.501846])\n",
        "    )\n",
        "    # Short aliases used by the inference code.\n",
        "    self.normalizer_x = self.data_normalizer_x\n",
        "    self.normalizer_y = self.data_normalizer_y\n",
        ""
      ],
      "metadata": {
        "id": "wSgkH9HU2vGW"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Inference inputs: trained checkpoint, its training config, and the features to predict on.\n",
        "ckpt_path, config_path, test_feature_path = (\n",
        "    'trackA0612/ckpt/epoch=189-step=39900.ckpt',\n",
        "    'trackA0612/ckpt/config.yaml',\n",
        "    'trackA0612/dataset/v45_val_feature.npy',\n",
        ")"
      ],
      "metadata": {
        "id": "ZJvacouC3RIb"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Output case ids: 658..722 minus the excluded ids; the kept order matches the feature rows.\n",
        "remove_item = ['661', '669', '670', '671', '680', '682', '685',\n",
        "               '694', '698', '699', '706', '707', '714', '716', '720']\n",
        "file_name = [str(case_id) for case_id in range(658, 723)\n",
        "             if str(case_id) not in remove_item]\n",
        "\n",
        "len(file_name)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6J5fNBzC3g1U",
        "outputId": "8457d245-d9db-4fa6-bfc7-0f52eeef4cb3"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "50"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "config = load_config(config_path)\n",
        "data_module = TrainingDataModule8in()\n",
        "data_module.setup()\n",
        "\n",
        "# NOTE(review): `solver` is not defined in this notebook; presumably it comes from\n",
        "# `from networks import *` -- import it explicitly so the dependency is visible.\n",
        "model = solver.load_from_checkpoint(ckpt_path)\n",
        "# model = class_builder(config['network'],\n",
        "#             normalizer_y=data_module.normalizer_y,\n",
        "#             normalizer_x=data_module.normalizer_x,\n",
        "#             use_wandb=False).load_from_checkpoint(\n",
        "#                 ckpt_path,\n",
        "#             )\n",
        "\n",
        "output_dir = 'infer_results'\n",
        "# np.save does not create missing directories, so make sure the target exists.\n",
        "os.makedirs(output_dir, exist_ok=True)\n",
        "\n",
        "test_feature = torch.from_numpy(np.load(test_feature_path)).to(torch.float32)\n",
        "feature_nor = data_module.normalizer_x.encode(test_feature)\n",
        "# Row i of the features is saved under file_name[i], so the two must line up one-to-one.\n",
        "assert feature_nor.shape[0] == len(file_name), (\n",
        "    f'{feature_nor.shape[0]} feature rows but {len(file_name)} output names')\n",
        "\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "  for i, case_id in enumerate(file_name):\n",
        "    pre = model(feature_nor[i].unsqueeze(0).to(model.device))\n",
        "    pre_nor = data_module.normalizer_y.decode(pre)\n",
        "    # Save as .npy; predictions come out as (1, n_points, 1), keep the 1-D per-point slice.\n",
        "    pressure = pre_nor.detach().cpu().numpy()[0, :, 0]\n",
        "    np.save(os.path.join(output_dir, 'press_' + case_id + '.npy'), pressure)\n",
        "\n",
        "print(f'saved {len(file_name)} predictions to {output_dir}/')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hk7OPWPX3l9g",
        "outputId": "1c964ac9-a33f-4c3f-cdd7-911d8738891a"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "==================== Parameters ====================\n",
            "ModelArgs(input_dim=8, output_dim=1, width=512, n_heads=16, n_layers=8, n_experts=4, act='SwiGLU', base=10000, normalization='RMSnorm', attn_type='normal', norm_eps=1e-05, moe_capacity_factor=1.5, dropout_prob=0.2, lr=0.0001)\n",
            "==================================================\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n",
            "torch.Size([1, 3586, 1])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Remove any previous archive first: zip -r updates an existing archive in place,\n",
        "# which would keep entries for files no longer present in infer_results/.\n",
        "!rm -f B_result.zip && zip -r B_result.zip infer_results/"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Jatl19tk33jD",
        "outputId": "4bcb367a-9a97-437d-fdb1-5140b2a1ff13"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  adding: infer_results/ (stored 0%)\n",
            "  adding: infer_results/press_689.npy (deflated 13%)\n",
            "  adding: infer_results/press_663.npy (deflated 12%)\n",
            "  adding: infer_results/press_666.npy (deflated 13%)\n",
            "  adding: infer_results/press_686.npy (deflated 13%)\n",
            "  adding: infer_results/press_696.npy (deflated 12%)\n",
            "  adding: infer_results/press_717.npy (deflated 13%)\n",
            "  adding: infer_results/press_673.npy (deflated 13%)\n",
            "  adding: infer_results/press_700.npy (deflated 13%)\n",
            "  adding: infer_results/press_667.npy (deflated 12%)\n",
            "  adding: infer_results/press_668.npy (deflated 12%)\n",
            "  adding: infer_results/press_697.npy (deflated 13%)\n",
            "  adding: infer_results/press_674.npy (deflated 13%)\n",
            "  adding: infer_results/press_677.npy (deflated 13%)\n",
            "  adding: infer_results/press_688.npy (deflated 13%)\n",
            "  adding: infer_results/press_679.npy (deflated 13%)\n",
            "  adding: infer_results/press_695.npy (deflated 12%)\n",
            "  adding: infer_results/press_681.npy (deflated 13%)\n",
            "  adding: infer_results/press_721.npy (deflated 13%)\n",
            "  adding: infer_results/press_712.npy (deflated 13%)\n",
            "  adding: infer_results/press_664.npy (deflated 12%)\n",
            "  adding: infer_results/press_722.npy (deflated 13%)\n",
            "  adding: infer_results/press_675.npy (deflated 12%)\n",
            "  adding: infer_results/press_709.npy (deflated 12%)\n",
            "  adding: infer_results/press_672.npy (deflated 13%)\n",
            "  adding: infer_results/press_691.npy (deflated 13%)\n",
            "  adding: infer_results/press_676.npy (deflated 12%)\n",
            "  adding: infer_results/press_710.npy (deflated 11%)\n",
            "  adding: infer_results/press_687.npy (deflated 13%)\n",
            "  adding: infer_results/press_684.npy (deflated 13%)\n",
            "  adding: infer_results/press_702.npy (deflated 12%)\n",
            "  adding: infer_results/press_703.npy (deflated 11%)\n",
            "  adding: infer_results/press_662.npy (deflated 13%)\n",
            "  adding: infer_results/press_683.npy (deflated 13%)\n",
            "  adding: infer_results/press_718.npy (deflated 13%)\n",
            "  adding: infer_results/press_701.npy (deflated 13%)\n",
            "  adding: infer_results/press_715.npy (deflated 13%)\n",
            "  adding: infer_results/press_658.npy (deflated 13%)\n",
            "  adding: infer_results/press_708.npy (deflated 13%)\n",
            "  adding: infer_results/press_693.npy (deflated 11%)\n",
            "  adding: infer_results/press_705.npy (deflated 13%)\n",
            "  adding: infer_results/press_692.npy (deflated 13%)\n",
            "  adding: infer_results/press_704.npy (deflated 13%)\n",
            "  adding: infer_results/press_678.npy (deflated 12%)\n",
            "  adding: infer_results/press_660.npy (deflated 13%)\n",
            "  adding: infer_results/press_665.npy (deflated 11%)\n",
            "  adding: infer_results/press_659.npy (deflated 13%)\n",
            "  adding: infer_results/press_711.npy (deflated 12%)\n",
            "  adding: infer_results/press_719.npy (deflated 12%)\n",
            "  adding: infer_results/press_713.npy (deflated 12%)\n",
            "  adding: infer_results/press_690.npy (deflated 10%)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "gMLA_9TP_UfC"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}